import torch
import torch.nn as nn
import torch.nn.functional as F
from PFSPModel_Lib import AddAndInstanceNormalization, FeedForward, Norm_MixedScore_MultiHeadAttention

class PFSPModel(nn.Module):
    """Encoder/decoder policy that picks the next job for every (batch, pomo) rollout.

    Usage is two-phase: ``pre_forward`` encodes a batch of problems once and
    caches everything that stays constant during a rollout; ``forward`` is
    then called once per decoding step.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        self.encoder = PFSP_Encoder(**model_params)
        self.decoder = PFSP_Decoder(**model_params)

        # Caches written by pre_forward() and read by forward().
        self.encoded_row = None
        self.encoded_col = None
        self.latent_vector = None

        # Learned vector used as the "last selected job" embedding while
        # state.current_node is still None (i.e. before the first selection).
        self.placeholder = nn.Parameter(torch.empty(self.model_params['embedding_dim']))
        nn.init.uniform_(self.placeholder, -1, 1)

    def pre_forward(self, reset_state, latent_var):
        """Encode the problems and prime the decoder's keys/values.

        Args:
            reset_state: object exposing ``problems`` with shape (batch, job, mc).
            latent_var: latent codes; reshaped here to
                (batch*pomo, 1, latent_cont_size + latent_disc_size), so it
                must hold exactly that many elements.
        """
        problems = reset_state.problems  # shape: (batch, job, mc)
        # Create every helper tensor next to the input so the method does not
        # depend on the process-wide default device.
        device = problems.device

        batch_size = problems.size(0)
        job_cnt = problems.size(1)
        machine_cnt = problems.size(2)
        embedding_dim = self.model_params['embedding_dim']
        latent_dimension = self.model_params['latent_cont_size'] + self.model_params['latent_disc_size']

        if self.training:
            pomo_size = self.model_params['pomo_size']
        else:
            pomo_size = self.model_params['eval_pomo_size']

        row_emb = torch.zeros(size=(batch_size, job_cnt, embedding_dim), device=device)
        # shape: (batch, job, embedding)
        col_emb = torch.zeros(size=(batch_size, machine_cnt, embedding_dim), device=device)
        # shape: (batch, mc, embedding)

        # Give every machine of an instance a distinct random one-hot column:
        # the first machine_cnt entries of a random permutation of
        # range(one_hot_seed_cnt) are used as the "hot" embedding positions.
        seed_cnt = self.model_params['one_hot_seed_cnt']
        rand = torch.rand(batch_size, seed_cnt, device=device)
        batch_rand_perm = rand.argsort(dim=1)
        rand_idx = batch_rand_perm[:, :machine_cnt]

        b_idx = torch.arange(batch_size, device=device)[:, None].expand(batch_size, machine_cnt)
        n_idx = torch.arange(machine_cnt, device=device)[None, :].expand(batch_size, machine_cnt)
        col_emb[b_idx, n_idx, rand_idx] = 1
        # shape: (batch, mc, embedding)

        self.encoded_row, self.encoded_col = self.encoder(row_emb, col_emb, problems)
        # encoded_row.shape: (batch, job, embedding)
        # The encoder pools the machines into a single vector per instance;
        # share that vector across all pomo rollouts.
        self.encoded_col = self.encoded_col.expand(batch_size, pomo_size, embedding_dim)
        # shape: (batch, pomo, embedding)

        #################################################################
        latent_cond_encoded_row = self.encoded_row.repeat_interleave(pomo_size, dim=0)
        # shape: (batch*pomo, job, embedding)

        # latent embedding (kept as given; forward() passes it to the decoder)
        self.latent_vector = latent_var

        re_latent_emb = self.latent_vector.reshape(batch_size * pomo_size, 1, latent_dimension)
        row_latent_emb = re_latent_emb.expand(batch_size * pomo_size, job_cnt, latent_dimension)
        # shape: (batch*pomo, job, latent)

        # latent conditioned K, V
        latent_kv = torch.cat([latent_cond_encoded_row, row_latent_emb], dim=-1)
        # shape: (batch*pomo, job, embedding+latent)

        # Set kv
        self.decoder.set_kv(latent_kv)

    def forward(self, state):
        """Select the next job for every rollout.

        Args:
            state: step state exposing ``BATCH_IDX`` / ``POMO_IDX`` (both
                (batch, pomo)), ``current_node`` (None before the first
                selection) and ``ninf_mask`` (batch, pomo, job).

        Returns:
            ``(selected, prob)``. ``selected`` has shape (batch, pomo).
            ``prob`` holds the probability of each selected job when sampling
            (training, or eval_type == 'softmax') and is None for greedy
            argmax decoding.
        """
        batch_size = state.BATCH_IDX.size(0)
        pomo_size = state.BATCH_IDX.size(1)

        if state.current_node is None:
            last_step = self.placeholder[None, None, :].repeat(batch_size, pomo_size, 1)
        else:
            last_step = _get_encoding(self.encoded_row, state.current_node)
        remain_mean_embeddings = _get_unvisited_node(self.encoded_row, state.ninf_mask)

        all_job_probs = self.decoder(self.encoded_col, last_step, remain_mean_embeddings,
                                     self.latent_vector, ninf_mask=state.ninf_mask)
        # shape: (batch, pomo, job)

        if self.training or self.model_params['eval_type'] == 'softmax':
            # Sampling itself needs no gradient; the gradient flows through
            # the probability that is looked up afterwards.
            with torch.no_grad():
                selected = all_job_probs.reshape(batch_size * pomo_size, -1).multinomial(1) \
                    .squeeze(dim=1).reshape(batch_size, pomo_size)
                # shape: (batch, pomo)
            prob = all_job_probs[state.BATCH_IDX, state.POMO_IDX, selected] \
                .reshape(batch_size, pomo_size)
            # shape: (batch, pomo)
        else:
            with torch.no_grad():
                selected = all_job_probs.argmax(dim=2)
                # shape: (batch, pomo)
            prob = None

        return selected, prob

def _get_encoding(encoded_nodes, node_index_to_pick):
    """Return the embedding of the indexed node for every (batch, pomo) entry.

    encoded_nodes.shape: (batch, problem, embedding)
    node_index_to_pick.shape: (batch, pomo)
    result shape: (batch, pomo, embedding)
    """
    emb_width = encoded_nodes.size(2)

    # Repeat each index along a new trailing axis so that gather() copies the
    # complete embedding row of the chosen node.
    row_selector = node_index_to_pick.unsqueeze(2).expand(-1, -1, emb_width)
    # shape: (batch, pomo, embedding)

    return encoded_nodes.gather(1, row_selector)

def _get_unvisited_node(encoded_nodes, mask):
    """Average the embeddings of the nodes whose mask entry equals 0.

    Args:
        encoded_nodes: (batch, node, embedding) node embeddings.
        mask: (batch, pomo, node); entries equal to 0 are counted, every other
            value is excluded (the caller passes ``state.ninf_mask``).

    Returns:
        (batch, pomo, embedding) mean over the counted nodes. A rollout with
        no counted node yields an all-zero vector instead of NaN.
    """
    unv_mask = (mask == 0).float()  # 1 for counted (unvisited) nodes, 0 otherwise

    unv_mask_expanded = unv_mask.unsqueeze(-1)  # shape: (batch, pomo, node, 1)
    unvisited_embeddings = encoded_nodes.unsqueeze(1) * unv_mask_expanded
    # shape: (batch, pomo, node, embedding)

    # Clamp so that a fully masked rollout divides a zero sum by 1 rather than
    # by 0, which would otherwise inject NaN into the decoder input.
    unv_count = unv_mask.sum(dim=-1, keepdim=True).clamp(min=1)
    # shape: (batch, pomo, 1)

    avg_unv_embeddings = unvisited_embeddings.sum(dim=-2) / unv_count
    # shape: (batch, pomo, embedding)
    return avg_unv_embeddings

########################################
# ENCODER
########################################
class PFSP_Encoder(nn.Module):
    """Stack of EncoderLayer modules followed by a projection that pools all columns."""

    def __init__(self, **model_params):
        super().__init__()
        layer_total = model_params['encoder_layer_num']
        emb_width = model_params['embedding_dim']
        column_total = model_params['mc_cnt']

        stacked = []
        for _ in range(layer_total):
            stacked.append(EncoderLayer(**model_params))
        self.layers = nn.ModuleList(stacked)

        # Maps the concatenation of mc_cnt column embeddings to one embedding.
        self.W_mc = nn.Linear(column_total * emb_width, emb_width)

    def forward(self, row_emb, col_emb, cost_mat):
        """Run all layers, then collapse the column embeddings into one vector.

        row_emb.shape: (batch, row_cnt, embedding)
        col_emb.shape: (batch, col_cnt, embedding)
        cost_mat.shape: (batch, row_cnt, col_cnt)
        Returns (row_emb, pooled_col) where pooled_col is (batch, 1, embedding).
        """
        for block in self.layers:
            row_emb, col_emb = block(row_emb, col_emb, cost_mat)

        # Projection pooling: lay the column embeddings side by side and let
        # W_mc reduce them (mean pooling over dim=1 was the alternative).
        batch_total = cost_mat.size(0)
        flat_width = cost_mat.size(2) * row_emb.size(2)
        flattened_cols = col_emb.reshape(batch_total, 1, flat_width).clone()
        pooled_col = self.W_mc(flattened_cols)

        return row_emb, pooled_col


class EncoderLayer(nn.Module):
    """One encoder layer: rows attend to columns and columns attend to rows."""

    def __init__(self, **model_params):
        super().__init__()
        self.row_encoding_block = EncodingBlock(**model_params)
        self.col_encoding_block = EncodingBlock(**model_params)

    def forward(self, row_emb, col_emb, cost_mat):
        """Update both embedding sets from the same (pre-update) inputs.

        row_emb.shape: (batch, row_cnt, embedding)
        col_emb.shape: (batch, col_cnt, embedding)
        cost_mat.shape: (batch, row_cnt, col_cnt)
        """
        # The column block sees the cost matrix from the columns' point of view.
        cost_mat_by_col = cost_mat.transpose(1, 2)

        updated_rows = self.row_encoding_block(row_emb, col_emb, cost_mat)
        updated_cols = self.col_encoding_block(col_emb, row_emb, cost_mat_by_col)
        return updated_rows, updated_cols

class EncodingBlock(nn.Module):
    """Mixed-score attention, then add-and-normalize, feed-forward, add-and-normalize."""

    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        emb_width = self.model_params['embedding_dim']
        heads = self.model_params['head_num']
        per_head = self.model_params['qkv_dim']
        projected_width = heads * per_head

        self.Wq = nn.Linear(emb_width, projected_width, bias=False)
        self.Wk = nn.Linear(emb_width, projected_width, bias=False)
        self.Wv = nn.Linear(emb_width, projected_width, bias=False)

        self.mixed_score_MHA = Norm_MixedScore_MultiHeadAttention(**model_params)
        self.multi_head_combine = nn.Linear(projected_width, emb_width)

        self.add_n_normalization_1 = AddAndInstanceNormalization(**model_params)
        self.feed_forward = FeedForward(**model_params)
        self.add_n_normalization_2 = AddAndInstanceNormalization(**model_params)

    def forward(self, row_emb, col_emb, cost_mat):
        """Update ``row_emb`` by attending over ``col_emb``.

        NOTE: rows and columns can be swapped by the caller as long as
        cost_mat.transpose(1, 2) is passed along with them.

        row_emb.shape: (batch, row_cnt, embedding)   -> queries
        col_emb.shape: (batch, col_cnt, embedding)   -> keys / values
        cost_mat.shape: (batch, row_cnt, col_cnt)
        Returns a tensor shaped like ``row_emb``.
        """
        heads = self.model_params['head_num']

        queries = reshape_by_heads(self.Wq(row_emb), head_num=heads)
        # shape: (batch, head_num, row_cnt, qkv_dim)
        keys = reshape_by_heads(self.Wk(col_emb), head_num=heads)
        values = reshape_by_heads(self.Wv(col_emb), head_num=heads)
        # shape: (batch, head_num, col_cnt, qkv_dim)

        attended = self.mixed_score_MHA(queries, keys, values, cost_mat)
        # shape: (batch, row_cnt, head_num*qkv_dim)
        combined = self.multi_head_combine(attended)
        # shape: (batch, row_cnt, embedding)

        hidden = self.add_n_normalization_1(row_emb, combined)
        return self.add_n_normalization_2(hidden, self.feed_forward(hidden))


########################################
# Decoder
########################################
class PFSP_Decoder(nn.Module):
    """Attention decoder producing a probability over jobs for each rollout.

    ``set_kv`` must be called once per problem batch before ``forward``; it
    caches the projected keys/values of the latent-conditioned job embeddings.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        embedding_dim = self.model_params['embedding_dim']
        # Stored for external use; not applied anywhere inside this class.
        self.temperature = self.model_params['temperature']
        head_num = self.model_params['head_num']
        qkv_dim = self.model_params['qkv_dim']
        latent_dim = self.model_params['latent_cont_size'] + self.model_params['latent_disc_size']

        # Query input = [current job, mean of unvisited jobs, machine embedding, latent].
        self.Wq_0 = nn.Linear(3 * embedding_dim + latent_dim, head_num * qkv_dim, bias=False)
        self.Wk = nn.Linear(embedding_dim + latent_dim, head_num * qkv_dim, bias=False)
        self.Wv = nn.Linear(embedding_dim + latent_dim, head_num * qkv_dim, bias=False)
        self.Wp = nn.Linear(embedding_dim + latent_dim, head_num * qkv_dim, bias=False)

        self.multi_head_combine = nn.Linear(head_num * qkv_dim, embedding_dim)
        self.feed_forward = FeedForward(**model_params)

        self.k = None  # saved key, for multi-head attention
        self.v = None  # saved value, for multi-head attention
        self.single_head_key = None  # saved key, for single-head attention
        self.q1 = None  # written by set_q1(); not read by forward()

    def set_kv(self, encoded_jobs):
        """Cache keys/values for the multi-head step and the pointer key.

        encoded_jobs.shape: (batch*pomo, job, embedding+latent)
        """
        head_num = self.model_params['head_num']

        self.k = reshape_by_heads(self.Wk(encoded_jobs), head_num=head_num)
        self.v = reshape_by_heads(self.Wv(encoded_jobs), head_num=head_num)
        # shape: (batch*pomo, head_num, job, qkv_dim)

        single_head_key = self.Wp(encoded_jobs)
        self.single_head_key = single_head_key.transpose(1, 2)
        # shape: (batch*pomo, head_num*qkv_dim, job)

    def set_q1(self, encoded_q1):
        """Store ``encoded_q1`` split into heads (kept for API compatibility).

        encoded_q1.shape: (batch, n, head_num*qkv_dim)  # n can be 1 or pomo
        """
        head_num = self.model_params['head_num']

        self.q1 = reshape_by_heads(encoded_q1, head_num=head_num)

    def forward(self, mc_embedding, current_job, unv_embeddings, latent_embedding, ninf_mask):
        """Return job-selection probabilities of shape (batch, pomo, job).

        Args:
            mc_embedding: (batch, pomo, embedding) pooled machine embedding.
            current_job: (batch, pomo, embedding) embedding of the last job.
            unv_embeddings: (batch, pomo, embedding) mean of unvisited jobs.
            latent_embedding: (batch, pomo, latent) latent code per rollout.
            ninf_mask: (batch, pomo, job) additive mask applied to the scores.
        """
        batch_size = ninf_mask.size(0)
        pomo_size = ninf_mask.size(1)
        job_size = ninf_mask.size(-1)
        head_num = self.model_params['head_num']

        #  Multi-Head Attention
        #######################################################
        state = torch.cat([current_job, unv_embeddings], dim=-1)
        context_embedding = self.Wq_0(torch.cat([state, mc_embedding, latent_embedding], dim=-1))
        # shape: (batch, pomo, head_num*qkv_dim)

        # Fold pomo into the batch axis: every rollout issues a single query
        # against its own latent-conditioned keys/values.
        flat_query = context_embedding.reshape(batch_size * pomo_size, 1, context_embedding.size(-1))
        flat_mask = ninf_mask.reshape(batch_size * pomo_size, 1, job_size)

        q = reshape_by_heads(flat_query, head_num=head_num)
        # shape: (batch*pomo, head_num, 1, qkv_dim)

        out_concat = self._multi_head_attention_for_decoder(q, self.k, self.v,
                                                            rank3_ninf_mask=flat_mask)
        # shape: (batch*pomo, 1, head_num*qkv_dim)
        mh_atten_out = self.multi_head_combine(out_concat)
        # shape: (batch*pomo, 1, embedding)

        # Residual feed-forward on top of the attention output.
        updated_context = self.feed_forward(mh_atten_out)
        pointer = mh_atten_out + updated_context
        # shape: (batch*pomo, 1, embedding)

        # Single-Head Attention, for probability calculation
        #######################################################
        # NOTE: the matmul needs embedding == head_num*qkv_dim, because the
        # pointer key comes from Wp while the pointer has embedding width.
        score = torch.matmul(pointer, self.single_head_key)
        # shape: (batch*pomo, 1, job)

        sqrt_embedding_dim = self.model_params['sqrt_embedding_dim']
        score_scaled = score / sqrt_embedding_dim

        # tanh clipping bounds the logits to [-logit_clipping, logit_clipping].
        logit_clipping = self.model_params['logit_clipping']
        score_clipped = logit_clipping * torch.tanh(score_scaled)

        score_masked = score_clipped + flat_mask
        score_masked = score_masked.reshape(batch_size, pomo_size, job_size)
        # shape: (batch, pomo, job)
        probs = F.softmax(score_masked, dim=2)

        return probs

    def _multi_head_attention_for_decoder(self, q, k, v, rank2_ninf_mask=None, rank3_ninf_mask=None):
        """Scaled dot-product attention with optional additive masks.

        q shape: (batch, head_num, n, qkv_dim)
        k, v shape: (batch, head_num, node, qkv_dim)
        rank2_ninf_mask.shape: (batch, node)      -- broadcast over heads and queries
        rank3_ninf_mask.shape: (batch, n, node)   -- broadcast over heads
        Returns (batch, n, head_num*qkv_dim).
        """
        batch_size = q.size(0)
        n = q.size(2)
        node_cnt = k.size(2)

        head_num = self.model_params['head_num']
        qkv_dim = self.model_params['qkv_dim']
        sqrt_qkv_dim = self.model_params['sqrt_qkv_dim']

        score = torch.matmul(q, k.transpose(2, 3))
        # shape: (batch, head_num, n, node)

        score_scaled = score / sqrt_qkv_dim

        if rank2_ninf_mask is not None:
            score_scaled = score_scaled + rank2_ninf_mask[:, None, None, :].expand(batch_size, head_num, n, node_cnt)
        if rank3_ninf_mask is not None:
            score_scaled = score_scaled + rank3_ninf_mask[:, None, :, :].expand(batch_size, head_num, n, node_cnt)

        # Functional softmax: no module object is built on every call.
        weights = F.softmax(score_scaled, dim=3)
        # shape: (batch, head_num, n, node)

        out = torch.matmul(weights, v)
        # shape: (batch, head_num, n, qkv_dim)

        out_transposed = out.transpose(1, 2)
        # shape: (batch, n, head_num, qkv_dim)

        out_concat = out_transposed.reshape(batch_size, n, head_num * qkv_dim)
        # shape: (batch, n, head_num*qkv_dim)

        return out_concat
    
########################################
# NN SUB FUNCTIONS
########################################

def reshape_by_heads(qkv, head_num):
    """Split the last axis into ``head_num`` heads and move the head axis forward.

    qkv.shape: (batch, n, head_num*key_dim)   : n can be either 1 or PROBLEM_SIZE
    result shape: (batch, head_num, n, key_dim)
    """
    lead, length = qkv.shape[0], qkv.shape[1]

    # (batch, n, head_num*key_dim) -> (batch, n, head_num, key_dim)
    split_heads = qkv.reshape(lead, length, head_num, -1)

    # Swap the n and head axes: (batch, head_num, n, key_dim)
    return split_heads.permute(0, 2, 1, 3)